// Copyright 2024 TF.Text Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow_text/core/kernels/phrase_tokenizer.h"

#include <algorithm>
#include <iostream>
#include <ostream>
#include <string>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow_text/core/kernels/whitespace_tokenizer_config_builder.h"

namespace tensorflow {
namespace text {

// Builds a PhraseTokenizer from a serialized PhraseTokenizerConfig flatbuffer.
//
// The tokenizer stores the config accessor obtained from `config_flatbuffer`
// and a view over the embedded whitespace config, so the buffer presumably
// has to stay alive for as long as the returned tokenizer is used.
/*static*/ absl::StatusOr<PhraseTokenizer> PhraseTokenizer::Create(
    const void* config_flatbuffer) {
  PhraseTokenizer result;

  // `GetPhraseTokenizerConfig()` is autogenerated by flatbuffer.
  result.phrase_config_ = GetPhraseTokenizerConfig(config_flatbuffer);
  const auto& config = result.phrase_config_;

  // Scalar settings copied out of the config. The emission probability is
  // stored as a percentage and converted to a fraction here.
  result.prob_ = static_cast<float>(config->prob()) / 100.0f;
  result.split_end_punctuation_ = config->split_end_punctuation();

  // Trie over the phrase vocabulary.
  result.trie_ = absl::make_unique<sentencepiece::DoubleArrayTrie>(
      config->vocab_trie()->nodes());

  // Whitespace tokenizer used for the initial word split.
  const auto& whitespace_config = config->whitespace_config();
  result.whitespace_config_str_ = absl::string_view(
      whitespace_config->c_str(), whitespace_config->size());
  result.whitespace_tokenizer_ = absl::make_unique<WhitespaceTokenizer>(
      WhitespaceTokenizerConfig(result.whitespace_config_str_));

  return std::move(result);
}

// Tokenizes `input` into phrases.
//
// The input is split into words by the whitespace tokenizer, the non-empty
// words are re-joined with single spaces (when `split_end_punctuation_` is
// set, a trailing special token is detached from its word first), and the
// resulting string is handed to FindPhraseTokens().
//
// Args:
//   input: the text to tokenize.
//   result_tokens: output; receives the emitted phrase strings.
//   result_token_ids: output; receives the id of each emitted phrase.
void PhraseTokenizer::Tokenize(const absl::string_view input,
                               std::vector<std::string>* result_tokens,
                               std::vector<int>* result_token_ids) {
  // Word level information.
  std::vector<std::string> words;
  whitespace_tokenizer_->Tokenize(input, &words);

  // Loop through words, considering 1-level punctuations.
  std::string all_str;
  for (const std::string& word : words) {
    if (word.empty()) {
      continue;
    }
    // Separate from the previous piece. Adding the separator up front (rather
    // than after each word) avoids a dangling trailing space when the last
    // words are empty.
    if (!all_str.empty()) {
      all_str += ' ';
    }

    bool split_done = false;
    if (split_end_punctuation_) {
      for (const auto& special_token : special_tokens_) {
        if (!absl::EndsWith(word, special_token)) {
          continue;
        }
        // Eg: split "can't" into "can 't".
        const std::size_t stem_length = word.size() - special_token.size();
        if (stem_length > 0) {
          all_str.append(word, 0, stem_length);
          all_str += ' ';
        }
        // When the word *is* the special token (e.g. a standalone "."), the
        // stem is empty; emitting the splitting space anyway would create a
        // doubled (or leading) space that FindPhraseTokens() reads as an
        // empty word and reports as a spurious unknown token.
        all_str += special_token;
        split_done = true;
        break;
      }
    }
    if (!split_done) {
      all_str += word;
    }
  }

  FindPhraseTokens(all_str, result_tokens, result_token_ids);
}

// Segments `cur_phrase` (words separated by single spaces) left to right.
//
// At each position the trie is queried through PhraseLookup(). A hit appends
// the matched phrase and its id, then jumps past the phrase and the character
// following it. A miss appends the unknown token and its id, then jumps past
// the next space (or stops when there is none).
void PhraseTokenizer::FindPhraseTokens(const std::string& cur_phrase,
                                       std::vector<std::string>* phrase_tokens,
                                       std::vector<int>* phrase_token_ids) {
  int pos = 0;
  while (pos < cur_phrase.size()) {
    bool matched = false;
    int matched_id = phrase_config_->unk_token_id();
    int matched_length = 0;
    PhraseLookup(cur_phrase, pos, &matched, &matched_id, &matched_length);

    if (matched) {
      // Found a phrase: emit it and step over the separator after it.
      phrase_tokens->push_back(cur_phrase.substr(pos, matched_length));
      phrase_token_ids->push_back(matched_id);
      pos += matched_length + 1;
      continue;
    }

    // No phrase starts here: the current word becomes the unknown token.
    phrase_tokens->push_back(phrase_config_->unk_token()->str());
    phrase_token_ids->push_back(phrase_config_->unk_token_id());

    const std::size_t next_space = cur_phrase.find(' ', pos);
    if (next_space == std::string::npos) {
      break;
    }
    pos = next_space + 1;
  }
}

void PhraseTokenizer::PhraseLookup(const std::string& token, int cur,
                                   bool* in_trie, int* emitted_phrase_id,
                                   int* emitted_phrase_length) {
  int matched_phrase_id = -1;
  int matched_phrase_length = 0;
  bool phrase_emitted = false;
  float prob = prob_;
  absl::BitGen* gen = &gen_;
  auto phrase_emit_func =
      [&token /*the input string*/,
       cur /*the current starting point for searching phrase*/,
       prob /*the probability to emit the current found phrase*/,
       in_trie /*whether a phrase in matched in the trie*/,
       emitted_phrase_id /*the token id of the emitted phrase*/,
       emitted_phrase_length /*the length of the emitted phrase*/,
       &matched_phrase_id /*the token id of the matched phrase*/,
       &matched_phrase_length /*the length of the matched phrase*/,
       &phrase_emitted /*whether the phrase is emitted or not*/,
       gen /*the random generator*/](
          const sentencepiece::DoubleArrayTrie::Match& m) {
        if (phrase_emitted || (cur + m.match_length < token.size() &&
                               token[cur + m.match_length] != ' ')) {
          // We should continue search without going through this function if:
          // 1: a phrase has already been emitted, or
          // 2: We located a phrase that split one single word.
          return;
        }

        matched_phrase_id = m.id;
        matched_phrase_length = m.match_length;
        *in_trie = true;
        if ((prob > 0) && absl::Bernoulli(*gen, prob)) {
          // Emit the current phrase.
          *emitted_phrase_id = m.id;
          *emitted_phrase_length = m.match_length;
          phrase_emitted = true;
        }
      };
  trie_->IteratePrefixMatches(
      sentencepiece::utils::string_view(token.data() + cur, token.size() - cur),
      phrase_emit_func);
  if (*in_trie && !phrase_emitted) {
    // We should use prev longest one as output as we prefer longer ones.
    *emitted_phrase_id = matched_phrase_id;
    *emitted_phrase_length = matched_phrase_length;
  }
}

// Maps every id in `input` to its string in the config's vocabulary array.
//
// Returns:
//   - FailedPreconditionError when the config was built without
//     support_detokenization.
//   - InvalidArgumentError when an id is negative or not smaller than the
//     vocabulary size (a missing vocabulary array is treated as empty).
//   - Otherwise the vocabulary strings, in the same order as `input`.
absl::StatusOr<std::vector<std::string>> PhraseTokenizer::DetokenizeToTokens(
    const absl::Span<const int> input) const {
  if (!phrase_config_->support_detokenization()) {
    return absl::FailedPreconditionError(
        "Detokenize function is only enabled when support_detokenization is "
        "true in the config flatbuffer. Please rebuild the model flatbuffer "
        "by setting support_detokenization=true.");
  }

  const auto* vocab_array = phrase_config_->vocab_array();
  const std::size_t vocab_size =
      vocab_array == nullptr ? 0 : vocab_array->size();

  std::vector<std::string> output_tokens;
  output_tokens.reserve(input.size());
  for (const int id : input) {
    // Ids come from the caller; reject anything outside the vocabulary
    // instead of indexing past the end of the flatbuffer vector.
    if (id < 0 || static_cast<std::size_t>(id) >= vocab_size) {
      return absl::InvalidArgumentError(
          "Token id " + std::to_string(id) +
          " is outside the vocabulary range [0, " +
          std::to_string(vocab_size) + ").");
    }
    output_tokens.emplace_back(vocab_array->Get(id)->string_view());
  }
  return output_tokens;
}

// Converts `input` ids back into a single string.
//
// The ids are first mapped to vocabulary strings by DetokenizeToTokens(),
// whose error is propagated. Without `split_end_punctuation_` the strings are
// joined with single spaces. With it, strings found in `special_tokens_` are
// attached directly to the preceding text, and every other string is
// preceded by a space unless it is the first thing in the output.
absl::StatusOr<std::string> PhraseTokenizer::Detokenize(
    const absl::Span<const int> input) const {
  SH_ASSIGN_OR_RETURN(std::vector<std::string> pieces,
                      DetokenizeToTokens(input));

  if (!split_end_punctuation_) {
    return absl::StrJoin(pieces, " ");
  }

  std::string text;
  for (const std::string& piece : pieces) {
    const bool needs_space = !special_tokens_.contains(piece) && !text.empty();
    if (needs_space) {
      text += " ";
    }
    text += piece;
  }
  return text;
}

}  // namespace text
}  // namespace tensorflow
